












import time
import hmac
import hashlib
import json
import requests
from quant.markets import Functions
from quant.markets.functions import get_asset_value
from quant.accounts import Order
from quant.utils import logging, Timer2, catch_exception, LimitDict, Iota
from quant.const import *

from quant.exchanges.util import (spot_value_factor, RecentChecker2, AscendingChecker, update_balance_default, prepare_placing,
                                  prepare_canceling, simulate_deal_spot, int_prec_to_precision)
from quant.exchanges.basics import ChannelBasic, HandlerBasic, ApiBasic, SpiBasic
from quant.exchanges.gate_util import GateSocket, format_symbol, resume_symbol, is_welcome_data, build_req


class GateSpotChannel(ChannelBasic):
    """Public market-data websocket channel for Gate spot."""

    def init(self):
        """Create the ``GateSocket`` wired to this channel's open/message/close callbacks."""
        self.socket = GateSocket(ws_host, self.on_open, self.on_message, self.on_close)

    def subscribe_book(self, ws):
        """Send a ``spot.order_book_update`` subscription for ``self.symbol``.

        The requested push interval is '1000ms' when ``self.frequency`` is
        ``DataFrequency.Low`` and '100ms' for every other frequency.

        :param ws: connection object exposing ``send(str)``.
        """
        interval = '1000ms' if self.frequency == DataFrequency.Low else '100ms'
        request = dict(
            time=int(time.time()),
            channel="spot.order_book_update",
            event="subscribe",
            payload=[format_symbol(self.symbol), interval],
        )
        ws.send(json.dumps(request))

    def subscribe_trade(self, ws):
        """Send a ``spot.trades`` subscription for ``self.symbol``.

        :param ws: connection object exposing ``send(str)``.
        """
        request = dict(
            time=int(time.time()),
            channel='spot.trades',
            event="subscribe",
            payload=[format_symbol(self.symbol)],
        )
        ws.send(json.dumps(request))

    def subscribe_ticker(self, ws):
        """Ticker subscription is not supported for Gate.

        :raises NotImplementedError: always.
        """
        raise NotImplementedError('Ticker for Gate not implemented')


class GateSpotHandler(HandlerBasic):
    """Turns raw Gate spot websocket messages into book / trade / ticker events."""

    def init(self):
        """Create the id checkers used to drop already-seen messages."""
        self.data_id_checker = RecentChecker2(50)
        self.data_id_checker_2 = AscendingChecker()

    def process_book(self, routing_key, recv_time, raw):
        """Apply one order-book update to ``self.book`` and publish the book.

        The close sentinel and welcome messages are delegated to
        ``process_close`` / ``info_welcome``. For a data message, ``result['u']``
        is used as the server id and ``result['t'] / 1000`` as the server time
        (presumably milliseconds converted to seconds). Messages whose id the
        checker has already seen are ignored after being reported to the info
        engine.
        """
        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        message = json.loads(raw)
        if is_welcome_data(message):
            return self.info_welcome()

        update = message['result']
        update_id, server_time = update['u'], update['t'] / 1000
        self.markets.info_engine.push_process_recv(update_id, server_time, recv_time)

        if not self.data_id_checker.is_new(update_id):
            return

        # Resolve both sides up front so a malformed message fails before the
        # book is touched.
        sides = (('buy', update['b']), ('sell', update['a']))
        book = self.book
        for side, levels in sides:
            for price, volume in levels:
                book[side, float(price)] = float(volume)

        self.check_book(book)
        self.push_book(routing_key, recv_time, server_time, update_id, book)

    def process_trade(self, routing_key, recv_time, raw):
        """Publish a single public trade as a one-element trade list.

        Uses ``result['id']`` as the server id and
        ``result['create_time_ms'] / 1000`` as the server time. Each published
        entry is ``[side, price, amount]`` with price and amount as floats.
        """
        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        message = json.loads(raw)
        if is_welcome_data(message):
            return self.info_welcome()

        trade = message['result']
        trade_id = trade['id']
        server_time = float(trade['create_time_ms']) / 1000
        self.markets.info_engine.push_process_recv(trade_id, server_time, recv_time)

        if not self.data_id_checker.is_new(trade_id):
            return

        entry = [trade['side'], float(trade['price']), float(trade['amount'])]
        self.push_trade(routing_key, recv_time, server_time, trade_id, [entry])

    def process_ticker(self, routing_key, recv_time, raw):
        """Publish best bid/ask as ``((bid, bid_size), (ask, ask_size))``.

        NOTE(review): unlike the book and trade handlers, this reads its fields
        from the top level of the message (not ``result``) and has no welcome
        check; the channel class in this file does not subscribe to tickers, so
        the expected message shape should be confirmed before use.
        """
        if raw == NAME_CHANNEL_CLOSE:
            return self.process_close(routing_key, recv_time)

        message = json.loads(raw)
        update_id = message['u']
        # No server timestamp is available here, hence ``None``.
        self.markets.info_engine.push_process_recv(update_id, None, recv_time)

        if not self.data_id_checker_2.is_new(update_id):
            return

        best_bid = (float(message['b']), float(message['B']))
        best_ask = (float(message['a']), float(message['A']))
        self.push_ticker(routing_key, recv_time, None, update_id, (best_bid, best_ask))


class GateSpotApi(ApiBasic):
    """REST trading API for Gate spot accounts (``/api/v4/spot/*``).

    Requests are built with ``build_req`` and dispatched through
    ``self.requesting.request_async``; the ``_on_*`` methods are the response
    callbacks handed to it. A response counts as successful when its status
    code is in the 2xx range.
    """

    def query_balance(self, symbol=None):
        """Request spot balances; ``_on_balance`` keeps only the assets of interest.

        :param symbol: 'base/quote' pair whose two assets should be updated, or
            ``None`` to update the assets in ``self._asset_added``.
        """
        path = '/api/v4/spot/accounts'
        req = build_req('GET', rest_host, path, self.api_key, self.secret_key, None)
        self.requesting.request_async(symbol, req, self._on_balance)

    def query_all_balance(self):
        """Request spot balances and replace the account balance with the result."""
        path = '/api/v4/spot/accounts'
        req = build_req('GET', rest_host, path, self.api_key, self.secret_key, None)
        self.requesting.request_async(None, req, self._on_all_balance)

    def query_all_position(self):
        """No-op for this API; returns ``None`` immediately."""
        return

    def query_all_margin(self):
        """No-op for this API; returns ``None`` immediately."""
        return

    def query_open_orders(self, symbol=None):
        """Request open orders; ``_on_open_orders`` filters them by symbol.

        :param symbol: pair to keep, or ``None`` to keep the pairs in
            ``self._symbol_added``.
        """
        path = '/api/v4/spot/open_orders'
        req = build_req('GET', rest_host, path, self.api_key, self.secret_key)
        self.requesting.request_async(symbol, req, self._on_open_orders)

    def query_all_open_orders(self):
        """Request open orders for every pair, without symbol filtering."""
        path = '/api/v4/spot/open_orders'
        req = build_req('GET', rest_host, path, self.api_key, self.secret_key)
        self.requesting.request_async(None, req, self._on_all_open_orders)

    def place_order(self, order):  # todo market_order
        """Submit a limit or post-only order.

        :param order: order to place; it is first passed through
            ``prepare_placing`` together with a freshly generated 't-' text id.
        :returns: the client id of the prepared order.
        :raises NotImplementedError: for any order type other than
            ``OrderType.Limit`` and ``OrderType.PostOnly``.
        """
        order = prepare_placing(order, self._create_order_text())

        path = '/api/v4/spot/orders'
        param = {
            'text': order.client_id,
            'currency_pair': format_symbol(order.symbol),
            'side': order.side,
            'price': str(order.price),
            'amount': str(order.amount),
            'type': 'limit',
            'time_in_force': 'gtc',
            'account': 'spot',
        }

        type_ = order.order_type
        if type_ == OrderType.PostOnly:
            # Post-only is expressed through the time-in-force field.
            param['time_in_force'] = 'poc'
        elif type_ != OrderType.Limit:
            err = 'OrderType {} not implemented'.format(type_)
            raise NotImplementedError(err)

        req = build_req('POST', rest_host, path, self.api_key, self.secret_key, param)
        self.account.order_processor.process_place_operation(order)
        self.requesting.request_async(order, req, self._on_place)
        return order.client_id

    def cancel_order(self, order):
        """Request cancellation of ``order`` by its exchange order id.

        Logs an error and sends nothing when the prepared order has no
        ``order_id``.
        """
        order = prepare_canceling(order)
        order_id = order.order_id

        if order_id is None:
            return logging.error('cancel order without oid: {}'.format(order))

        path = '/api/v4/spot/orders/{}'.format(order_id)
        param = {
            'currency_pair': format_symbol(order.symbol),
            'order_id': order_id,
            'account': 'spot',
        }

        req = build_req('DELETE', rest_host, path, self.api_key, self.secret_key, param)
        self.account.order_processor.process_cancel_operation(order)
        self.requesting.request_async(order, req, self._on_cancel)

    def _on_balance(self, symbol, response):
        """Update ``self.account.balance`` for the assets of interest and publish it."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_balance', symbol, response)

        if symbol is None:
            interests = self._asset_added
        else:
            # 'base/quote' -> {'base', 'quote'}
            interests = set(symbol.split('/'))

        balance = self.account.balance
        for dic in response.json():
            asset = dic['currency'].lower()
            if asset in interests:
                update_balance_default(balance, asset, float(dic['available']), float(dic['locked']))
        self._put_data(UserEvent.Balance, balance)

    def _on_all_balance(self, _, response):
        """Replace the account balance with every non-zero asset in the response."""
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_balance', _, response)

        balance = {}
        for dic in response.json():
            asset = dic['currency'].lower()
            free = float(dic['available'])
            frozen = float(dic['locked'])
            # Assets with neither free nor locked funds are left out.
            if free or frozen:
                update_balance_default(balance, asset, free, frozen)

        # Mutate the existing dict in place rather than rebinding it.
        self.account.balance.clear()
        self.account.balance.update(balance)

        self._put_data(UserEvent.Balance, balance)

    @staticmethod
    def _orders_from_json(symbol, raw_orders):
        """Convert one pair's raw open-order dicts into pending ``Order`` objects."""
        orders = []
        for d in raw_orders:
            oid = d['id']
            # When using this as a template, note that some exchanges do not
            # assign a client id to manually placed orders.
            cid = check_cid(d['text'], oid)
            orders.append(Order(
                symbol=symbol,
                side=d['side'].lower(),
                price=float(d['price']),
                amount=float(d['left']),
                client_id=cid,
                order_id=oid,
                status=OrderStatus.Pending,
            ))
        return orders

    def _on_open_orders(self, symbol, response):
        """Process open orders for ``symbol`` (or for ``self._symbol_added``).

        The order processor is always invoked, even with an empty list.
        """
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_open_orders', symbol, response)

        if symbol is None:
            interest = self._symbol_added
        else:
            interest = {symbol}

        open_orders = []
        for dic in response.json():
            s = resume_symbol(dic['currency_pair'])
            if s in interest:
                open_orders.extend(self._orders_from_json(s, dic['orders']))

        self.account.order_processor.process_open_orders(open_orders)
        self._put_data(UserEvent.OpenOrders, self.account.orders)

    def _on_all_open_orders(self, param, response):
        """Process open orders for every pair in the response.

        Unlike ``_on_open_orders``, nothing is processed or published when the
        response contains no orders.
        """
        if int(response.status_code) // 100 != 2:
            return self._fail_request('query_all_open_orders', '', response)

        open_orders = []
        for dic in response.json():
            s = resume_symbol(dic['currency_pair'])
            open_orders.extend(self._orders_from_json(s, dic['orders']))

        if open_orders:
            self.account.order_processor.process_open_orders(open_orders)
            self._put_data(UserEvent.OpenOrders, self.account.orders)

    def _on_place(self, order, response):
        """Record the outcome of a place request on a copy of ``order`` and publish it.

        The status code is checked before the body is parsed, so an error
        response whose body is not JSON still marks the order ``PlaceFailed``
        instead of raising from ``response.json()``.
        """
        order = order.copy()

        placed = False
        if int(response.status_code) // 100 == 2:
            j = response.json()
            if 'id' in j:
                placed = True
                order.order_id = j['id']
                order.status = OrderStatus.Pending

                if order.order_type == OrderType.Market:
                    order.status = OrderStatus.FullyFilled

        if not placed:
            self._fail_request('place_order', order, response)
            order.status = OrderStatus.PlaceFailed

        order = self.account.order_processor.update_order(order)
        self._put_data(UserEvent.Order, order)

    def _on_cancel(self, order, response):
        """Record the outcome of a cancel request on a copy of ``order`` and publish it."""
        order = order.copy()

        if int(response.status_code) // 100 != 2:
            self._fail_request('cancel_order', order, response)

        # NOTE(review): only a status code of -1 marks the cancel as failed;
        # any other non-2xx response is reported above but the order is still
        # treated as Canceled. Presumably -1 denotes a transport-level failure
        # and an exchange-side rejection means the order is already gone --
        # confirm against the request layer.
        if int(response.status_code) == -1:
            order.status = OrderStatus.CancelFailed
        else:
            order.status = OrderStatus.Canceled

        order = self.account.order_processor.update_order(order)
        self._put_data(UserEvent.Order, order)

    def _create_order_text(self):
        """Return a new client id prefixed with 't-'."""
        text = 't-{}'.format(self._create_cid())
        return text


class GateSpotSpi(SpiBasic):
    """Private websocket feed (balances and order updates) for Gate spot."""

    # Not referenced by the methods below.
    _listen_key = ''
    _extend_key_started = False

    def add_symbol(self, symbol):
        """Register ``symbol``, ensure the socket is started and subscribe if connected."""
        super().add_symbol(symbol)
        self.connect_once()
        if self._socket.is_connected():
            self._subscribe_symbols([symbol])

    def _create_socket(self):
        """Return a ``GateSocket`` bound to this object's callbacks."""
        return GateSocket(ws_host, self._on_open, self._on_message, self._on_close)

    def _on_open(self, ws):
        """On connect, subscribe to balances and to orders of every added symbol."""
        self._request('spot.balances', 'subscribe')
        self._subscribe_symbols(list(self._symbol_added))

    def _subscribe_symbols(self, symbols):
        """Subscribe to ``spot.orders`` for the given symbols."""
        pairs = list(map(format_symbol, symbols))
        self._request('spot.orders', 'subscribe', pairs)
        # The 'spot.usertrades' channel is intentionally not subscribed.

    def _request(self, channel, event, payload=None):
        """Send a signed request over the private socket.

        The signature is the hex HMAC-SHA512 of
        ``channel=<channel>&event=<event>&time=<time>`` keyed with the secret key.

        :param channel: channel name, e.g. 'spot.orders'.
        :param event: event name, e.g. 'subscribe'.
        :param payload: optional payload; omitted from the message when ``None``.
        """
        now = int(time.time())
        to_sign = 'channel=%s&event=%s&time=%d' % (channel, event, now)
        signature = hmac.new(
            self.secret_key.encode("utf8"),
            to_sign.encode("utf8"),
            hashlib.sha512,
        ).hexdigest()

        body = {"time": now, "channel": channel, "event": event}
        if payload is not None:
            body['payload'] = payload
        body['auth'] = {"method": "api_key", "KEY": self.api_key, "SIGN": signature}

        self._socket.assert_send(json.dumps(body), 5)

    def _on_message(self, ws, message):
        """Route an incoming message to the balance or order handler.

        Messages with event 'subscribe' are ignored; messages on any other
        channel are logged as unknown.
        """
        parsed = json.loads(message)
        channel = parsed['channel']

        if parsed['event'] == 'subscribe':
            return

        if channel == 'spot.balances':
            self._handle_balance(parsed)
            return
        if channel == 'spot.orders':
            self._handle_order(parsed)
            return
        # 'spot.usertrades' is not handled.
        logging.warn('GateSpotSpi unknown msg:{}'.format(message))

    def _handle_balance(self, data):
        """Update the account balance from a ``spot.balances`` message and publish it.

        The frozen amount is computed as ``total - available``.
        """
        balance = self.account.balance

        for entry in data['result']:
            available = float(entry['available'])
            locked = float(entry['total']) - available
            update_balance_default(balance, entry['currency'].lower(), available, locked)

        self._put_data(UserEvent.Balance, balance)

    def _handle_order(self, data):
        """Translate ``spot.orders`` events into order updates and deals.

        Event mapping: 'put' -> Pending, 'update' -> PartFilled, 'finish' ->
        FullyFilled when nothing is left, otherwise Canceled.

        :raises Exception: when an entry carries any other event value.
        """
        for item in data['result']:
            text = item['text']
            oid = item['id']
            cid = check_cid(text, oid)

            remaining = float(item['left'])
            event = item['event']
            if event == 'finish':
                status = OrderStatus.FullyFilled if remaining == 0 else OrderStatus.Canceled
            elif event == 'put':
                status = OrderStatus.Pending
            elif event == 'update':
                status = OrderStatus.PartFilled
            else:
                raise Exception('Unknown order status:{}'.format(event))

            order = Order(
                symbol=resume_symbol(item['currency_pair']),
                side=item['side'].lower(),
                price=float(item['price']),
                amount=remaining,
                order_id=oid,
                client_id=cid,
                status=status,
            )

            # The dealt amount is how much the tracked order's amount shrank;
            # it must be read before update_order is called below.
            tracked = self.account.orders.get(cid)
            previous = None if tracked is None else tracked.amount
            deal = 0 if previous is None else previous - remaining

            order = self.account.order_processor.update_order(order)
            self._put_data(UserEvent.Order, order)

            if deal != 0:
                self._put_deal(order, deal)

    # def _handle_deal(self, data):
    #     for d in data['result']:
    #         cid = d['text']
    #         oid = d['id']
    #         cid = check_cid(cid, oid)
    #
    #         deal = float(d['amount'])
    #
    #         order = Order(
    #             symbol=resume_symbol(d['currency_pair']),
    #             side=d['side'].lower(),
    #             price=float(d['price']),
    #             order_id=oid,
    #             client_id=cid,
    #         )
    #
    #         self._put_deal(order, deal)


class GateSpotFunctions(Functions):
    """Public REST helpers for Gate spot markets, using ``self.session`` with a 3s timeout."""

    def get_all_ticker(self):
        """Return ticker data for every '/usdt' pair that has both a bid and an ask.

        :returns: ``{symbol: {'price', 'vol', 'bid', 'ask'}}``, where 'vol' is
            the quote volume multiplied by ``get_asset_value('usdt')``; an
            empty dict when the request does not return a 2xx status.
        """
        path = '/api/v4/spot/tickers'
        url = rest_host + path

        response = self.session.get(url, timeout=3)
        if int(response.status_code) // 100 != 2:
            return {}
        j = response.json()

        result = {}
        _usdt = get_asset_value('usdt')

        for d in j:
            symbol = resume_symbol(d['currency_pair'])
            # Match the whole quote asset, the same way get_all_precision does.
            if not symbol.endswith('/usdt'):
                continue

            bid = d['highest_bid']
            ask = d['lowest_ask']
            # Pairs without a quote on either side report an empty string.
            if bid == '' or ask == '':
                continue

            result[symbol] = {
                'price': float(d['last']),
                'vol': float(d['quote_volume']) * _usdt,
                'bid': float(bid),
                'ask': float(ask),
            }
        return result

    def get_recent_trade(self, symbol, limit=1000):
        """Return recent public trades for ``symbol`` in reversed response order.

        :param symbol: pair to query.
        :param limit: maximum number of trades to request.
        :returns: list of ``(id, side, price, amount, time)`` tuples, where
            ``time`` is ``create_time_ms / 1000``; an empty list when the
            request does not return a 2xx status.
        """
        path = '/api/v4/spot/trades'
        url = '{}{}?currency_pair={}&limit={}'.format(rest_host, path, format_symbol(symbol), limit)

        response = self.session.get(url, timeout=3)
        if int(response.status_code) // 100 != 2:
            # Same container type as the success path.
            return []

        result = [
            (
                d['id'],
                d['side'],
                float(d['price']),
                float(d['amount']),
                float(d['create_time_ms']) / 1000,
            )
            for d in response.json()
        ]
        result.reverse()

        return result

    def get_all_precision(self):
        """Return ``{symbol: (price_unit, amount_unit)}`` for every '/usdt' pair.

        Units are derived from the integer 'precision' and 'amount_precision'
        fields via ``int_prec_to_precision``. Returns an empty dict when the
        request does not return a 2xx status.
        """
        path = '/api/v4/spot/currency_pairs'
        url = rest_host + path

        response = self.session.get(url, timeout=3)
        if int(response.status_code) // 100 != 2:
            return {}

        result = {}
        for d in response.json():
            symbol = resume_symbol(d['id'])
            if not symbol.endswith('/usdt'):
                continue
            price_unit = int_prec_to_precision(int(d['precision']))
            amount_unit = int_prec_to_precision(int(d['amount_precision']))

            result[symbol] = (price_unit, amount_unit)
        return result

    def get_value_factor(self, symbol):
        """Return ``spot_value_factor(symbol)``."""
        return spot_value_factor(symbol)

    def get_all_contract_size(self):
        """Return a contract size of 1 for every symbol reported by ``get_all_ticker``."""
        return {s: 1 for s in self.get_all_ticker()}

    def simulate_deal(self, order, deal_amount, mid_price, account):
        """Delegate to ``simulate_deal_spot`` with the same arguments."""
        return simulate_deal_spot(order, deal_amount, mid_price, account)


def check_cid(cid, oid):
    """Return the client id to use for an order reported by the exchange.

    Ids starting with 't-' are returned unchanged. Any other id is replaced by
    a 'user-<n>' id that is remembered per order id in ``cid_limit_dict``, so
    the same order id maps to the same replacement while it stays in that dict.

    :param cid: the order's text / client id as reported by the exchange.
    :param oid: the exchange order id, used as the lookup key.
    """
    if cid.startswith('t-'):
        return cid

    if oid in cid_limit_dict:
        return cid_limit_dict[oid]

    generated = 'user-{}'.format(cid_iota.next())
    cid_limit_dict[oid] = generated
    return generated


# Gate API v4 endpoints used by the classes in this module.
ws_host = 'wss://api.gateio.ws/ws/v4/'
rest_host = 'https://api.gateio.ws'

# State backing check_cid: a counter for generated 'user-<n>' ids and the
# order-id -> generated-id mapping (presumably bounded to 100 entries by
# LimitDict -- confirm against quant.utils).
cid_iota = Iota()
cid_limit_dict = LimitDict(100)


# Manual test harness: only runs when this module is executed as a script.
# The try_* helpers below are ad-hoc experiments against the live exchange and
# are invoked by editing this section by hand.
if __name__ == '__main__':
    import utils
    from quant.utils import perf_intv
    import keys
    from quant.markets import Markets, functions
    from quant.accounts import Accounts, Order
    from quant.utils import set_test_mode
    from utils import feed_test_raw, print_req
    # Switch the framework into test mode before any helper runs.
    set_test_mode()

    def try_channel():
        """Subscribe to the book and trade feeds of one symbol, then block on input."""
        s = 'axs/usdt'
        mar = Markets()
        mar.add_market(Event.Book, Exchange.Gate, s, DataFrequency.Normal)
        mar.add_market(Event.Trade, Exchange.Gate, s, DataFrequency.Normal)

        # utils.print_book(mar)
        # utils.print_raw(mar)
        # utils.print_trades(mar)
        # utils.perf_raw_to_data(mar)
        # utils.check_raw_to_data_duplicate(mar)

        # Keep the process alive until interrupted.
        while True:
            input(':')

    def try_api():
        """Create a Gate account from a stored key and issue repeated open-order queries."""
        key = keys.get_key('gate_9293')
        account = Accounts().create_account('Gate', key)
        api = account.api

        # account.add_symbol('btc/usdt')
        account.info_engine.show_user_data()
        account.info_engine.show_problems()

        # api.query_balance()
        # api.query_all_balance()
        # api.query_open_orders('eth/usdt')
        # api.query_all_open_orders()

        # order = Order('eth/usdt', 'buy', 1, 1)
        # order = Order('doge/usdt', 'buy', 0.05, 21, order_type=OrderType.PostOnly)
        # api.place_order(order)

        api.add_symbol('doge/usdt')
        # Each query is issued four times in a row.
        api.query_open_orders()
        api.query_open_orders()
        api.query_open_orders()
        api.query_open_orders()

        api.query_all_open_orders()
        api.query_all_open_orders()
        api.query_all_open_orders()
        api.query_all_open_orders()
        # input('wait order:')
        # for o in list(account.orders.values()):
        #     input('cancel:')
        #     api.cancel_order(o)

        # spi = GateSpotSpi(key)
        # spi.add_symbol('doge/usdt')

        # def on_ud(ud):
        #     print(ud)
        # account.subscribe(UserEvent.Order, on_ud)

        # def on_deal(order, amount):
        #     print('deal amount:', amount)
        # info = spi.account.info_engine
        # info.subscribe(info.Deal, on_deal)

        # Keep the process alive until interrupted.
        while True:
            input(':')

    def try_spi():
        """Connect the private feed for one symbol and print tracked orders on each Enter."""
        key = keys.get_key('gate_9293')
        spi = GateSpotSpi(key)
        acc = spi.account
        # acc.info_engine.show_user_data()
        acc.info_engine.show_problems()

        def on_ud(ud):
            """Print the payload of each order event."""
            print(ud.data)
        acc.subscribe(UserEvent.Order, on_ud)

        acc.info_engine.show_deal()

        spi.connect_once()
        spi.add_symbol('doge/usdt')

        # Dump the account's current orders every time Enter is pressed.
        while True:
            print('-----------orders-----------')
            for i in acc.orders.values():
                print(i)
            input(':')

    # try_spi()

    # Currently active experiment: fetch all tickers and print them ranked by
    # their 'vol' field, largest first.
    func = GateSpotFunctions()
    a = func.get_all_ticker()
    a = [(i, j) for i, j in a.items()]
    a.sort(key=lambda x: x[1]['vol'], reverse=True)
    for i, j in enumerate(a):
        print(i, j)

    # a = func.get_recent_trade('btc/usdt')
    # for i in a:
    #     print(i)

    # print("len:", len(a))
